{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "fLtNzk95zuks"
      },
      "source": [
        "##### Copyright 2025 Google LLC."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "cellView": "form",
        "id": "fnRDEtTRz1Ah"
      },
      "outputs": [],
      "source": [
        "# @title Licensed under the Apache License, Version 2.0 (the \"License\");\n",
        "# you may not use this file except in compliance with the License.\n",
        "# You may obtain a copy of the License at\n",
        "#\n",
        "# https://www.apache.org/licenses/LICENSE-2.0\n",
        "#\n",
        "# Unless required by applicable law or agreed to in writing, software\n",
        "# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
        "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
        "# See the License for the specific language governing permissions and\n",
        "# limitations under the License."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "IJK05W_Rz9I3"
      },
      "source": [
        "# Build Chatbot with Gradio & Llama.cpp\n",
        "\n",
        "Author: Sitam Meur\n",
        "\n",
        "*   GitHub: [github.com/sitamgithub-MSIT](https://github.com/sitamgithub-MSIT/)\n",
        "*   X: [@sitammeur](https://x.com/sitammeur)\n",
        "\n",
        "Description: Google recently released Gemma 3 QAT—the [Quantization Aware Trained (QAT) Gemma 3](https://huggingface.co/collections/google/gemma-3-qat-67ee61ccacbf2be4195c265b) checkpoints. These models maintain similar quality to half precision while using three times less memory. This notebook demonstrates creating a user-friendly chat interface for the [gemma-3-1b-it-qat](https://huggingface.co/google/gemma-3-1b-it-qat-q4_0-gguf) text model using Llama.cpp (for inference) and Gradio (for user interface).\n",
        "\n",
        "<table align=\"left\">\n",
        "  <td>\n",
        "    <a target=\"_blank\" href=\"https://colab.research.google.com/github/google-gemini/gemma-cookbook/blob/main/Gemma/[Gemma_3]Gradio_LlamaCpp_Chatbot.ipynb\"><img src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" />Run in Google Colab</a>\n",
        "  </td>\n",
        "</table>"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "JEogAuER0MPC"
      },
      "source": [
        "## Setup\n",
        "\n",
        "### Select the Colab runtime\n",
        "To complete this tutorial, you'll need to have a Colab runtime with sufficient resources to run the Gemma 3 QAT model. In this case, you can use a CPU:\n",
        "\n",
        "1. In the upper-right of the Colab window, select **▾ (Additional connection options)**.\n",
        "2. Select **Change runtime type**.\n",
        "3. Under **Hardware accelerator**, select **CPU**.\n",
        "\n",
        "### Gemma Setup\n",
        "\n",
        "**Before we dive into the tutorial, let's get you set up:**\n",
        "\n",
        "1. **Hugging Face Account:**  If you don't already have one, you can create a free Hugging Face account by clicking [here](https://huggingface.co/join).\n",
        "2. **Model Access:** Head over to the [model page](https://huggingface.co/google/gemma-3-1b-it-qat-q4_0-gguf) and accept the usage conditions.\n",
        "3. **Hugging Face Token:**  Generate a Hugging Face access (preferably `write` permission) token by clicking [here](https://huggingface.co/settings/tokens). You'll need this token later in the tutorial.\n",
        "\n",
        "**Once you've completed these steps, you're ready to move on to the next section where we'll set up environment variables in your Colab environment.**"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "D3DIfibe1I4W"
      },
      "source": [
        "### Configure your HF token\n",
        "\n",
        "Add your Hugging Face token to the Colab Secrets manager to securely store it.\n",
        "\n",
        "1. Open your Google Colab notebook and click on the 🔑 Secrets tab in the left panel. <img src=\"https://storage.googleapis.com/generativeai-downloads/images/secrets.jpg\" alt=\"The Secrets tab is found on the left panel.\" width=50%>\n",
        "2. Create a new secret with the name `HF_TOKEN`.\n",
        "3. Copy/paste your token key into the Value input box of `HF_TOKEN`.\n",
        "4. Toggle the button on the left to allow notebook access to the secret."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Ox59ZaemK3j1"
      },
      "outputs": [],
      "source": [
        "import os\n",
        "from google.colab import userdata\n",
        "# Note: `userdata.get` is a Colab API. If you're not using Colab, set the env\n",
        "# vars as appropriate for your system.\n",
        "os.environ[\"HF_TOKEN\"] = userdata.get(\"HF_TOKEN\")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "ircX6fQA05Wr"
      },
      "source": [
        "### Install dependencies\n",
        "Run the cell below to install all the required dependencies."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "hFoZpGiQEiBW"
      },
      "outputs": [],
      "source": [
        "!pip install -q huggingface_hub scikit-build-core llama-cpp-python llama-cpp-agent gradio"
      ]
    },
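    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "As a quick optional sanity check, you can confirm that `llama-cpp-python` imports correctly and print its version (the exact version will vary):"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "# Sanity check: confirm llama-cpp-python imports and report its version\n",
        "import llama_cpp\n",
        "\n",
        "print(llama_cpp.__version__)"
      ]
    },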
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "gko1egPaLEyy"
      },
      "source": [
        "### Log into Hugging Face Hub"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "GiBP1bfhLFTz"
      },
      "outputs": [],
      "source": [
        "from huggingface_hub import login\n",
        "\n",
        "login(os.environ[\"HF_TOKEN\"])"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "QRO7WkLaLeW6"
      },
      "source": [
        "## Instantiate the Gemma 3 QAT model\n",
        "\n",
        "We’ll use the Hugging Face Hub to download the 4-bit (Q4_0) quantized Gemma 3 1B instruction-tuned text-only model. New quantized Gemma 3 models, created using Quantization-Aware Training (QAT), offer improved accessibility through reduced memory usage (3x less than bfloat16) without significant accuracy loss.  QAT simulates low-precision operations during training for smaller, faster models.\n",
        "\n",
        "Let's get started by downloading the model from Hugging Face Hub."
      ]
    },
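    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The following back-of-the-envelope sketch compares raw weight storage at bfloat16 versus 4-bit. It is only an approximation: the headline ~3x figure is lower than this naive ~4x because Q4_0 blocks carry scale overhead and some tensors (such as embeddings) may be kept at higher precision."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "# Rough estimate of weight memory for a ~1B-parameter model.\n",
        "# Ignores activations, the KV cache, and quantization block overhead.\n",
        "params = 1_000_000_000\n",
        "\n",
        "bf16_gb = params * 2 / 1e9  # bfloat16: 2 bytes per parameter\n",
        "q4_gb = params * 0.5 / 1e9  # Q4_0: ~4 bits (0.5 bytes) per parameter\n",
        "\n",
        "print(f\"bfloat16 weights: ~{bf16_gb:.1f} GB\")\n",
        "print(f\"Q4_0 weights:     ~{q4_gb:.1f} GB\")\n",
        "print(f\"Naive reduction:  ~{bf16_gb / q4_gb:.0f}x\")"
      ]
    },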
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Bi5cfGCP2TR1"
      },
      "source": [
        "### Loading the model from HF Hub"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "BGNdy0NsqAo6"
      },
      "outputs": [],
      "source": [
        "# Download gguf model files\n",
        "from huggingface_hub import hf_hub_download\n",
        "\n",
        "if not os.path.exists(\"./models\"):\n",
        "    os.makedirs(\"./models\")\n",
        "\n",
        "hf_hub_download(\n",
        "    repo_id=\"google/gemma-3-1b-it-qat-q4_0-gguf\",\n",
        "    filename=\"gemma-3-1b-it-q4_0.gguf\",\n",
        "    local_dir=\"./models\",\n",
        ")"
      ]
    },
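    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Once the download completes, a quick optional check confirms the GGUF file is on disk and reports its size:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "# Optional check: confirm the downloaded GGUF file exists and report its size\n",
        "model_file = \"./models/gemma-3-1b-it-q4_0.gguf\"\n",
        "\n",
        "size_gb = os.path.getsize(model_file) / (1024**3)\n",
        "print(f\"{model_file}: {size_gb:.2f} GB\")"
      ]
    },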
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "fYnY8mpQ2R9w"
      },
      "source": [
        "### Prompt Formatting\n",
        "\n",
        "Gemma 3 model requires specific formatting to understand the roles of different participants in a conversation. The prompt format is as follows:\n",
        "\n",
        "```\n",
        "<bos><start_of_turn>user\n",
        "{system_prompt}\n",
        "\n",
        "{prompt}<end_of_turn>\n",
        "<start_of_turn>model\n",
        "```"
      ]
    },
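    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "As a minimal, library-independent illustration, the next cell fills in this template for a single turn; note how the system prompt is folded into the first user turn. The formatter defined right after reproduces the same layout automatically."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "# Minimal, library-independent sketch of the Gemma 3 prompt template.\n",
        "# The system prompt is folded into the first user turn.\n",
        "def build_gemma3_prompt(system_prompt: str, user_message: str) -> str:\n",
        "    return (\n",
        "        \"<bos><start_of_turn>user\\n\"\n",
        "        f\"{system_prompt}\\n\\n\"\n",
        "        f\"{user_message}<end_of_turn>\\n\"\n",
        "        \"<start_of_turn>model\\n\"\n",
        "    )\n",
        "\n",
        "print(build_gemma3_prompt(\"You are a helpful assistant.\", \"What is gravity?\"))"
      ]
    },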
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "IOPoBYD3IjFV"
      },
      "outputs": [],
      "source": [
        "# Define the prompt markers for Gemma 3\n",
        "from llama_cpp_agent.chat_history.messages import Roles\n",
        "from llama_cpp_agent.messages_formatter import MessagesFormatter, PromptMarkers\n",
        "\n",
        "gemma_3_prompt_markers = {\n",
        "    Roles.system: PromptMarkers(\"\", \"\\n\"),  # System prompt should be included within user message\n",
        "    Roles.user: PromptMarkers(\"<start_of_turn>user\\n\", \"<end_of_turn>\\n\"),\n",
        "    Roles.assistant: PromptMarkers(\"<start_of_turn>model\\n\", \"<end_of_turn>\\n\"),\n",
        "    Roles.tool: PromptMarkers(\"\", \"\"),  # If need tool support\n",
        "}\n",
        "\n",
        "# Create the formatter\n",
        "gemma_3_formatter = MessagesFormatter(\n",
        "    pre_prompt=\"\",  # No pre-prompt\n",
        "    prompt_markers=gemma_3_prompt_markers,\n",
        "    include_sys_prompt_in_first_user_message=True,  # Include system prompt in first user message\n",
        "    default_stop_sequences=[\"<end_of_turn>\", \"<start_of_turn>\"],\n",
        "    strip_prompt=False,  # Don't strip whitespace from the prompt\n",
        "    bos_token=\"<bos>\",  # Beginning of sequence token for Gemma 3\n",
        "    eos_token=\"<eos>\",  # End of sequence token for Gemma 3\n",
        ")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "V92UpXaE2Wff"
      },
      "source": [
        "## Chat with Gemma 3\n",
        "\n",
        "This function handles loading the model and generating responses. Streaming provides real-time generation rather than waiting for the complete response."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "dB_YfAdX5JTL"
      },
      "outputs": [],
      "source": [
        "from typing import List, Tuple\n",
        "from llama_cpp import Llama\n",
        "from llama_cpp_agent import LlamaCppAgent\n",
        "from llama_cpp_agent.providers import LlamaCppPythonProvider\n",
        "from llama_cpp_agent.chat_history import BasicChatHistory\n",
        "from llama_cpp_agent.chat_history.messages import Roles\n",
        "\n",
        "llm = None\n",
        "llm_model = None\n",
        "\n",
        "def respond(\n",
        "    message: str,\n",
        "    history: List[Tuple[str, str]],\n",
        "    model: str = \"gemma-3-1b-it-q4_0.gguf\",\n",
        "    system_message: str = \"You are a helpful assistant.\",\n",
        "    max_tokens: int = 1024,\n",
        "    temperature: float = 0.7,\n",
        "    top_p: float = 0.95,\n",
        "    top_k: int = 40,\n",
        "    repeat_penalty: float = 1.1,\n",
        "):\n",
        "    \"\"\"\n",
        "    Respond to a message using the Gemma3 model via Llama.cpp.\n",
        "\n",
        "    Args:\n",
        "        - message (str): The message to respond to.\n",
        "        - history (List[Tuple[str, str]]): The chat history.\n",
        "        - model (str): The model to use.\n",
        "        - system_message (str): The system message to use.\n",
        "        - max_tokens (int): The maximum number of tokens to generate.\n",
        "        - temperature (float): The temperature of the model.\n",
        "        - top_p (float): The top-p of the model.\n",
        "        - top_k (int): The top-k of the model.\n",
        "        - repeat_penalty (float): The repetition penalty of the model.\n",
        "\n",
        "    Returns:\n",
        "        str: The response to the message.\n",
        "    \"\"\"\n",
        "    try:\n",
        "        # Load the global variables\n",
        "        global llm\n",
        "        global llm_model\n",
        "\n",
        "        # Ensure model is not None\n",
        "        if model is None:\n",
        "            model = \"gemma-3-1b-it-q4_0.gguf\"\n",
        "\n",
        "        # Load the model\n",
        "        if llm is None or llm_model != model:\n",
        "            # Check if model file exists\n",
        "            model_path = f\"models/{model}\"\n",
        "            if not os.path.exists(model_path):\n",
        "                yield f\"Error: Model file not found at {model_path}. Please check your model path.\"\n",
        "                return\n",
        "\n",
        "            llm = Llama(\n",
        "                model_path=f\"models/{model}\",\n",
        "                flash_attn=False,\n",
        "                n_gpu_layers=0,\n",
        "                n_batch=8,\n",
        "                n_ctx=2048,\n",
        "                n_threads=8,\n",
        "                n_threads_batch=8,\n",
        "            )\n",
        "            llm_model = model\n",
        "        provider = LlamaCppPythonProvider(llm)\n",
        "\n",
        "        # Create the agent\n",
        "        agent = LlamaCppAgent(\n",
        "            provider,\n",
        "            system_prompt=f\"{system_message}\",\n",
        "            custom_messages_formatter=gemma_3_formatter,\n",
        "            debug_output=True,\n",
        "        )\n",
        "\n",
        "        # Set the settings like temperature, top-k, top-p, max tokens, etc.\n",
        "        settings = provider.get_provider_default_settings()\n",
        "        settings.temperature = temperature\n",
        "        settings.top_k = top_k\n",
        "        settings.top_p = top_p\n",
        "        settings.max_tokens = max_tokens\n",
        "        settings.repeat_penalty = repeat_penalty\n",
        "        settings.stream = True\n",
        "\n",
        "        messages = BasicChatHistory()\n",
        "\n",
        "        # Add the chat history\n",
        "        for msn in history:\n",
        "            user = {\"role\": Roles.user, \"content\": msn[0]}\n",
        "            assistant = {\"role\": Roles.assistant, \"content\": msn[1]}\n",
        "            messages.add_message(user)\n",
        "            messages.add_message(assistant)\n",
        "\n",
        "        # Get the response stream\n",
        "        stream = agent.get_chat_response(\n",
        "            message,\n",
        "            llm_sampling_settings=settings,\n",
        "            chat_history=messages,\n",
        "            returns_streaming_generator=True,\n",
        "            print_output=False,\n",
        "        )\n",
        "\n",
        "        # Generate the response\n",
        "        outputs = \"\"\n",
        "        for output in stream:\n",
        "            outputs += output\n",
        "            yield outputs\n",
        "\n",
        "    # Handle exceptions that may occur during the process\n",
        "    except Exception as e:\n",
        "        raise Exception(f\"An error occurred: {str(e)}\") from e"
      ]
    },
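    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Because `respond` is a generator that yields the accumulated text, you can sanity-check it directly before building the UI. The first call loads the model, so expect a short delay on CPU:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "# Quick test: stream a short reply directly from the respond() generator.\n",
        "# Each yielded value is the accumulated text so far, so the last one is the full reply.\n",
        "final_text = \"\"\n",
        "for partial in respond(\"In one sentence, what is quantization?\", history=[]):\n",
        "    final_text = partial\n",
        "\n",
        "print(final_text)"
      ]
    },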
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "F1PbdoSr2Y-u"
      },
      "source": [
        "## Gradio UI\n",
        "\n",
        "This offers a user interface with customizable buttons, editable messages, and options for advanced users to tune model parameters like temperature and top-p."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "2vu56cxP548d"
      },
      "outputs": [],
      "source": [
        "# Create a chat interface\n",
        "import gradio as gr\n",
        "\n",
        "demo = gr.ChatInterface(\n",
        "    respond,\n",
        "    examples=[[\"What is the capital of France?\"], [\"Tell me something about artificial intelligence.\"], [\"What is gravity?\"]],\n",
        "    additional_inputs_accordion=gr.Accordion(\n",
        "        label=\"⚙️ Parameters\", open=False, render=False\n",
        "    ),\n",
        "    additional_inputs=[\n",
        "        gr.Dropdown(\n",
        "            choices=[\n",
        "                \"gemma-3-1b-it-q4_0.gguf\",\n",
        "            ],\n",
        "            value=\"gemma-3-1b-it-q4_0.gguf\",\n",
        "            label=\"Model\",\n",
        "            info=\"Select the AI model to use for chat\",\n",
        "        ),\n",
        "        gr.Textbox(\n",
        "            value=\"You are a helpful assistant.\",\n",
        "            label=\"System Prompt\",\n",
        "            info=\"Define the AI assistant's personality and behavior\",\n",
        "            lines=2,\n",
        "        ),\n",
        "        gr.Slider(\n",
        "            minimum=512,\n",
        "            maximum=2048,\n",
        "            value=1024,\n",
        "            step=1,\n",
        "            label=\"Max Tokens\",\n",
        "            info=\"Maximum length of response (higher = longer replies)\",\n",
        "        ),\n",
        "        gr.Slider(\n",
        "            minimum=0.1,\n",
        "            maximum=2.0,\n",
        "            value=0.7,\n",
        "            step=0.1,\n",
        "            label=\"Temperature\",\n",
        "            info=\"Creativity level (higher = more creative, lower = more focused)\",\n",
        "        ),\n",
        "        gr.Slider(\n",
        "            minimum=0.1,\n",
        "            maximum=1.0,\n",
        "            value=0.95,\n",
        "            step=0.05,\n",
        "            label=\"Top-p\",\n",
        "            info=\"Nucleus sampling threshold\",\n",
        "        ),\n",
        "        gr.Slider(\n",
        "            minimum=1,\n",
        "            maximum=100,\n",
        "            value=40,\n",
        "            step=1,\n",
        "            label=\"Top-k\",\n",
        "            info=\"Limit vocabulary choices to top K tokens\",\n",
        "        ),\n",
        "        gr.Slider(\n",
        "            minimum=1.0,\n",
        "            maximum=2.0,\n",
        "            value=1.1,\n",
        "            step=0.1,\n",
        "            label=\"Repetition Penalty\",\n",
        "            info=\"Penalize repeated words (higher = less repetition)\",\n",
        "        ),\n",
        "    ],\n",
        "    submit_btn=\"Send\",\n",
        "    stop_btn=\"Stop\",\n",
        "    chatbot=gr.Chatbot(scale=1, show_copy_button=True, resizable=True),\n",
        "    flagging_mode=\"never\",\n",
        "    editable=True,\n",
        "    cache_examples=False,\n",
        ")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "wKGE9W0ooEO3"
      },
      "source": [
        "Finally, we’ll launch our chat interface with Gradio"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "kmhrakxX58av"
      },
      "outputs": [],
      "source": [
        "# Launch the chat interface\n",
        "demo.launch(debug=True)"
      ]
    },
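    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "When you're done chatting, interrupt the cell above; if the server is still running in the background, you can release its port with `demo.close()`:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "# Shut down the Gradio server and release its port (safe even if already stopped)\n",
        "demo.close()"
      ]
    },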
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "gxlqWI812bJL"
      },
      "source": [
        "## What's Next?\n",
        "\n",
        "That's it! Here are some ideas to explore further:\n",
        "\n",
        "- **Explore the Gemma family models:** Visit [Gemma Open Models](https://ai.google.dev/gemma) to learn about the latest updates regarding the Gemma family models, new capabilities, versions, and more.\n",
        "\n",
        "- **Gradio Customization:** Explore the [Gradio documentation](https://www.gradio.app/docs) to learn about customizing your chat interface, adding new options and features.\n",
        "\n",
        "- **Share Your Gradio Dashboard:** Check out the [Sharing your Gradio app](https://www.gradio.app/guides/sharing-your-app) page to learn how to safely share your Gradio dashboard with others!"
      ]
    }
  ],
  "metadata": {
    "colab": {
      "name": "[Gemma_3]Gradio_LlamaCpp_Chatbot.ipynb",
      "toc_visible": true
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
